{
  "cells": [
    {
      "cell_type": "markdown",
      "metadata": {
        "colab_type": "text",
        "id": "view-in-github"
      },
      "source": [
        "<a href=\"https://colab.research.google.com/github/Denys88/rl_games/blob/master/notebooks/brax_training.ipynb\" target=\"_parent\"><img src=\"https://colab.research.google.com/assets/colab-badge.svg\" alt=\"Open In Colab\"/></a>"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "JbZsMYmyZiVr"
      },
      "outputs": [],
      "source": [
        "# %pip (unlike !pip, which runs in a subshell) installs into the environment of the running kernel.\n",
        "%pip install git+https://github.com/Denys88/rl_games"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "Mq4PwGLn13Wm"
      },
      "outputs": [],
      "source": [
        "#@title Brax training example\n",
        "#@markdown ## ⚠️ PLEASE NOTE:\n",
        "#@markdown This colab runs using a GPU runtime. From the Colab menu, choose Runtime > Change Runtime Type, then select **'GPU'** in the dropdown.\n",
        "\n",
        "from datetime import datetime\n",
        "import functools\n",
        "import os\n",
        "\n",
        "from IPython.display import HTML, clear_output\n",
        "\n",
        "import jax\n",
        "import jax.numpy as jnp\n",
        "import matplotlib.pyplot as plt\n",
        "\n",
        "try:\n",
        "  import brax\n",
        "except ImportError:\n",
        "  !pip install git+https://github.com/google/brax.git@main\n",
        "  clear_output()\n",
        "  import brax\n",
        "\n",
        "from brax import envs\n",
        "from brax import jumpy as jp\n",
        "from brax.io import html\n",
        "from brax.io import model"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "colab": {
          "base_uri": "https://localhost:8080/"
        },
        "id": "DnUgCbN3OwLB",
        "outputId": "e266fbc5-06ad-4679-e208-c5b485acaadd"
      },
      "outputs": [],
      "source": [
        "!nvidia-smi -L"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "6qvHCGgpxrvZ"
      },
      "outputs": [],
      "source": [
        "%load_ext tensorboard"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {},
      "outputs": [],
      "source": [
        "%tensorboard --logdir 'runs/'"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "5Zym5GEjw_iE"
      },
      "outputs": [],
      "source": [
        "## ant brax config:\n",
        "ant_config = {'params': {'algo': {'name': 'a2c_continuous'},\n",
        "  'config': {'bound_loss_type': 'regularisation',\n",
        "   'bounds_loss_coef': 0.0,\n",
        "   'clip_value': True,\n",
        "   'critic_coef': 4,\n",
        "   'e_clip': 0.2,\n",
        "   'entropy_coef': 0.0,\n",
        "   'env_config': {'env_name': 'ant', 'seed': 5},\n",
        "   'env_name': 'brax',\n",
        "   'gamma': 0.99,\n",
        "   'grad_norm': 1.0,\n",
        "   'horizon_length': 8,\n",
        "   'kl_threshold': 0.008,\n",
        "   'learning_rate': '3e-4',\n",
        "   'lr_schedule': 'adaptive',\n",
        "   'max_epochs': 5000,\n",
        "   'mini_epochs': 4,\n",
        "   'minibatch_size': 32768,\n",
        "   'name': 'ant-brax',\n",
        "   'normalize_advantage': True,\n",
        "   'normalize_input': True,\n",
        "   'normalize_value': True,\n",
        "   'num_actors': 4096,\n",
        "   'player': {'render': True},\n",
        "   'ppo': True,\n",
        "   'reward_shaper': {'scale_value': 0.1},\n",
        "   'schedule_type': 'standard',\n",
        "   'score_to_win': 20000,\n",
        "   'tau': 0.95,\n",
        "   'truncate_grads': True,\n",
        "   'use_smooth_clamp': True,\n",
        "   'value_bootstrap': True},\n",
        "  'model': {'name': 'continuous_a2c_logstd'},\n",
        "  'network': {'mlp': {'activation': 'elu',\n",
        "    'initializer': {'name': 'default'},\n",
        "    'units': [256, 128, 64]},\n",
        "   'name': 'actor_critic',\n",
        "   'separate': False,\n",
        "   'space': {'continuous': {'fixed_sigma': True,\n",
        "     'mu_activation': 'None',\n",
        "     'mu_init': {'name': 'default'},\n",
        "     'sigma_activation': 'None',\n",
        "     'sigma_init': {'name': 'const_initializer', 'val': 0}}}},\n",
        "  'seed': 5}}"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "dt2q0HgmxDrb"
      },
      "outputs": [],
      "source": [
        "## humanoid brax config, based on the OpenAI Gym MuJoCo config (it should use the same network and normalization) so the result can be rendered:\n",
        "humanoid_config = {'params': {'algo': {'name': 'a2c_continuous'},\n",
        "  'config': {'bound_loss_type': 'regularisation',\n",
        "   'bounds_loss_coef': 0.0,\n",
        "   'clip_value': True,\n",
        "   'critic_coef': 4,\n",
        "   'e_clip': 0.2,\n",
        "   'entropy_coef': 0.0,\n",
        "   'env_config': {'env_name': 'humanoid', 'seed': 5},\n",
        "   'env_name': 'brax',\n",
        "   'gamma': 0.99,\n",
        "   'grad_norm': 1.0,\n",
        "   'horizon_length': 16,\n",
        "   'kl_threshold': 0.008,\n",
        "   'learning_rate': '3e-4',\n",
        "   'lr_schedule': 'adaptive',\n",
        "   'max_epochs': 5000,\n",
        "   'mini_epochs': 5,\n",
        "   'minibatch_size': 32768,\n",
        "   'name': 'humanoid-brax',\n",
        "   'normalize_advantage': True,\n",
        "   'normalize_input': True,\n",
        "   'normalize_value': True,\n",
        "   'num_actors': 4096,\n",
        "   'player': {'render': True},\n",
        "   'ppo': True,\n",
        "   'reward_shaper': {'scale_value': 0.1},\n",
        "   'schedule_type': 'standard',\n",
        "   'score_to_win': 20000,\n",
        "   'tau': 0.95,\n",
        "   'truncate_grads': True,\n",
        "   'use_smooth_clamp': True,\n",
        "   'value_bootstrap': True},\n",
        "  'model': {'name': 'continuous_a2c_logstd'},\n",
        "  'network': {'mlp': {'activation': 'elu',\n",
        "    'initializer': {'name': 'default'},\n",
        "    'units': [512, 256, 128]},\n",
        "   'name': 'actor_critic',\n",
        "   'separate': False,\n",
        "   'space': {'continuous': {'fixed_sigma': True,\n",
        "     'mu_activation': 'None',\n",
        "     'mu_init': {'name': 'default'},\n",
        "     'sigma_activation': 'None',\n",
        "     'sigma_init': {'name': 'const_initializer', 'val': 0}}}},\n",
        "  'seed': 5}}"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "colab": {
          "base_uri": "https://localhost:8080/"
        },
        "id": "swX1oGIQavbI",
        "outputId": "5489fcfb-4426-499d-909e-5b9afa7546dd"
      },
      "outputs": [],
      "source": [
        "import copy\n",
        "\n",
        "import yaml\n",
        "from rl_games.torch_runner import Runner\n",
        "\n",
        "env_name = 'ant'  # @param ['ant', 'humanoid']\n",
        "\n",
        "# Base rl_games config for every environment selectable above.\n",
        "configs = {\n",
        "    'ant' : ant_config,\n",
        "    'humanoid' : humanoid_config\n",
        "}\n",
        "# Checkpoint file that is restored for each environment after training.\n",
        "networks = {\n",
        "    'ant' : 'runs/ant/nn/ant-brax.pth',\n",
        "    'humanoid' : 'runs/humanoid/nn/humanoid-brax.pth'\n",
        "}\n",
        "\n",
        "# Work on a private copy: the overrides below must not leak into\n",
        "# ant_config / humanoid_config, so re-running the config cells or\n",
        "# switching env_name always starts from the untouched base values.\n",
        "config = copy.deepcopy(configs[env_name])\n",
        "network_path = networks[env_name]\n",
        "# NOTE(review): presumably this selects the runs/<env_name>/ output directory\n",
        "# that the paths in `networks` point into -- confirm against rl_games.\n",
        "config['params']['config']['full_experiment_name'] = env_name\n",
        "# Shorter run than the 5000 epochs set in the base configs.\n",
        "config['params']['config']['max_epochs'] = 1000\n",
        "\n",
        "runner = Runner()\n",
        "runner.load(config)\n",
        "runner.run({\n",
        "    'train': True,\n",
        "})"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "3s-95B-KlqE1"
      },
      "outputs": [],
      "source": [
        "from rl_games.envs.brax import BraxEnv\n",
        "\n",
        "from IPython.display import HTML, IFrame, display, clear_output\n",
        "import os"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "colab": {
          "base_uri": "https://localhost:8080/"
        },
        "id": "EikAyoGpmDpl",
        "outputId": "dd2006b1-e02a-4907-cf79-1b7d98d657e0"
      },
      "outputs": [],
      "source": [
        "# Build a player from the runner and restore the checkpoint at network_path.\n",
        "agent = runner.create_player()\n",
        "agent.restore(network_path)\n",
        "\n",
        "# Separate Brax environment with a single actor, built from the same\n",
        "# env_config the runner was loaded with.\n",
        "num_actors = 1\n",
        "env_config = runner.params['config']['env_config']\n",
        "env = BraxEnv('', num_actors, **env_config)"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "colab": {
          "base_uri": "https://localhost:8080/"
        },
        "id": "L6GRBDhsubx6",
        "outputId": "75ff7fb9-d19b-4b3d-91ca-132dc3e557a6"
      },
      "outputs": [],
      "source": [
        "class QP:\n",
        "    \"\"\"Snapshot of one simulation state that keeps only `pos` and `rot`.\n",
        "\n",
        "    Axis 0 of both arrays is squeezed away; it has size 1 here, presumably\n",
        "    the environment batch axis since the env is built with one actor.\n",
        "    \"\"\"\n",
        "\n",
        "    def __init__(self, qp):\n",
        "        self.pos = jnp.squeeze(qp.pos, axis=0)\n",
        "        self.rot = jnp.squeeze(qp.rot, axis=0)\n",
        "\n",
        "\n",
        "# Fresh episode bookkeeping, so the cell can be re-run on its own.\n",
        "qps = []\n",
        "total_reward = 0\n",
        "num_steps = 0\n",
        "obs = env.reset()\n",
        "is_done = False\n",
        "\n",
        "# Play one episode, recording the state seen *before* each step.\n",
        "while not is_done:\n",
        "    qps.append(QP(env.env._state.qp))\n",
        "    act = agent.get_action(obs)\n",
        "    # The env is stepped with an extra leading axis added to the action.\n",
        "    obs, reward, is_done, info = env.step(act.unsqueeze(0))\n",
        "    total_reward += reward.item()\n",
        "    num_steps += 1\n",
        "\n",
        "print('Total Reward: ', total_reward)\n",
        "print('Num steps: ', num_steps)"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "-iM6S8V55i5a"
      },
      "outputs": [],
      "source": [
        "def visualize(sys, qps):\n",
        "    \"\"\"Wrap the Brax HTML rendering of a trajectory for notebook display.\n",
        "\n",
        "    Args:\n",
        "        sys: system object forwarded unchanged to `html.render`.\n",
        "        qps: recorded states forwarded unchanged to `html.render`.\n",
        "\n",
        "    Returns:\n",
        "        An `HTML` display object built from the rendered markup.\n",
        "    \"\"\"\n",
        "    rendered_markup = html.render(sys, qps)\n",
        "    return HTML(rendered_markup)"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "colab": {
          "base_uri": "https://localhost:8080/",
          "height": 480
        },
        "id": "7xoJWkUI5kbF",
        "outputId": "59be7c1b-2700-44ff-9b90-b5ba1aba80e7"
      },
      "outputs": [],
      "source": [
        "# Show the recorded episode inline in the notebook.\n",
        "trajectory_view = visualize(env.env._env.sys, qps)\n",
        "display(trajectory_view)"
      ]
    }
  ],
  "metadata": {
    "accelerator": "GPU",
    "colab": {
      "collapsed_sections": [],
      "include_colab_link": true,
      "name": "rlg_colab_brax.ipynb",
      "provenance": []
    },
    "kernelspec": {
      "display_name": "Python 3.7.11 ('rl')",
      "language": "python",
      "name": "python3"
    },
    "language_info": {
      "codemirror_mode": {
        "name": "ipython",
        "version": 3
      },
      "file_extension": ".py",
      "mimetype": "text/x-python",
      "name": "python",
      "nbconvert_exporter": "python",
      "pygments_lexer": "ipython3",
      "version": "3.7.11"
    },
    "vscode": {
      "interpreter": {
        "hash": "44a9891a118a7be9dddc573b7a9be338decd7d2acd5c055c04ccaf7d7ad0ee03"
      }
    }
  },
  "nbformat": 4,
  "nbformat_minor": 0
}
